"""Add created_at to UserModel and ProjectModel

Revision ID: afbc600ff2b2
Revises: c20626d03cfb
Create Date: 2024-10-16 14:31:49.040804

"""

import uuid
from datetime import timedelta

import sqlalchemy as sa
import sqlalchemy_utils
from alembic import op

import dstack._internal.server.models
from dstack._internal.utils.common import get_current_datetime

# revision identifiers, used by Alembic.
# `revision` identifies this migration and `down_revision` names the migration
# it is applied on top of; together they place this script in Alembic's
# migration chain, so they must not be edited by hand.
# No branch labels or cross-branch dependencies are declared for this revision.
revision = "afbc600ff2b2"
down_revision = "c20626d03cfb"
branch_labels = None
depends_on = None


def _partial_table(name: str) -> sa.Table:
    """Describe table ``name`` with just the columns this migration touches.

    The real tables have more columns; only ``id`` (used to match rows) and
    ``created_at`` (the column being backfilled) are declared here. Each call
    registers the table in its own fresh ``MetaData``.
    """
    return sa.Table(
        name,
        sa.MetaData(),
        sa.Column(
            "id",
            sqlalchemy_utils.UUIDType(binary=False),
            default=uuid.uuid4,
            primary_key=True,
        ),
        sa.Column(
            "created_at",
            dstack._internal.server.models.NaiveDateTime(),
            nullable=True,
        ),
    )


users_table = _partial_table("users")
projects_table = _partial_table("projects")


def _backfill_created_at(table: sa.Table, last_created_at) -> None:
    """Assign synthetic ``created_at`` values to every existing row of ``table``.

    The real creation times cannot be recovered, so only relative order is
    preserved: the last row returned by the SELECT gets ``last_created_at`` and
    each earlier row gets one second less than the row after it.

    Args:
        table: Partial table description with ``id`` and ``created_at`` columns.
        last_created_at: Timestamp given to the most recently returned row.
    """
    bind = op.get_bind()
    # NOTE(review): SQL does not guarantee row order without ORDER BY. This is
    # best-effort and relies on the backend returning rows in storage order
    # (e.g. SQLite's rowid order for a full table scan). The whole partial table
    # (id + created_at) is selected on purpose: selecting only `id` could let
    # the planner answer from the primary-key index and return rows in UUID
    # order instead.
    rows = bind.execute(sa.select(table)).all()
    if not rows:
        return
    params = [
        {"_id": row.id, "created_at": last_created_at - timedelta(seconds=i)}
        for i, row in enumerate(reversed(rows))
    ]
    update_stmt = (
        table.update()
        .where(table.c.id == sa.bindparam("_id"))
        .values(created_at=sa.bindparam("created_at"))
    )
    # A list of parameter dicts makes SQLAlchemy run the UPDATE as executemany.
    bind.execute(update_stmt, params)


def upgrade() -> None:
    """Add a NOT NULL ``created_at`` column to ``projects`` and ``users``.

    The column is added as nullable first, backfilled for existing rows, and
    only then switched to NOT NULL so the migration succeeds on non-empty tables.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("projects", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=True)
        )
    with op.batch_alter_table("users", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=True)
        )

    # Both tables share one anchor time; absolute values are meaningless, only
    # per-table ordering is preserved.
    # NOTE(review): assumes NaiveDateTime accepts whatever get_current_datetime()
    # returns (e.g. strips tzinfo if it is tz-aware) — confirm.
    last_created_at = get_current_datetime()
    _backfill_created_at(users_table, last_created_at)
    _backfill_created_at(projects_table, last_created_at)

    with op.batch_alter_table("projects", schema=None) as batch_op:
        batch_op.alter_column("created_at", nullable=False)
    with op.batch_alter_table("users", schema=None) as batch_op:
        batch_op.alter_column("created_at", nullable=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    """Undo ``upgrade`` by removing the ``created_at`` column again.

    The column is dropped from ``users`` first and then from ``projects``,
    each inside its own batch operation.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    for table_name in ("users", "projects"):
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            batch_op.drop_column("created_at")
    # ### end Alembic commands ###
